{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed PyTorch/XLA Basics\n",
    "\n",
    "Beginning with PyTorch/XLA 2.0, Kaggle supports the new PJRT preview runtime on TPU VMs! For more information about PJRT, see the [PyTorch/XLA GitHub repository](https://github.com/pytorch/xla/blob/master/docs/pjrt.md).\n",
    "\n",
    "PyTorch/XLA is a package that lets PyTorch run on TPU devices. Kaggle provides a free v3-8 TPU VM. v3-8 TPUs have 8 logical devices: 4 TPU chips, each having 2 cores. This notebook shows how to run simple distributed operations on a TPU using the PJRT runtime. For more information about the Cloud TPU architecture, [see the official documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n",
    "\n",
    "At the time of writing Kaggle Notebooks on TPU VM are preinstalled with Python 3.10 and PT/XLA 2.1. See below for the exact versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:01.251900Z",
     "iopub.status.busy": "2024-01-10T19:30:01.251567Z",
     "iopub.status.idle": "2024-01-10T19:30:01.378200Z",
     "shell.execute_reply": "2024-01-10T19:30:01.377121Z",
     "shell.execute_reply.started": "2024-01-10T19:30:01.251872Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python 3.10.13\n"
     ]
    }
   ],
   "source": [
    "!python --version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:01.380390Z",
     "iopub.status.busy": "2024-01-10T19:30:01.380122Z",
     "iopub.status.idle": "2024-01-10T19:30:22.624453Z",
     "shell.execute_reply": "2024-01-10T19:30:22.623753Z",
     "shell.execute_reply.started": "2024-01-10T19:30:01.380364Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.1.0+cu121'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:22.625785Z",
     "iopub.status.busy": "2024-01-10T19:30:22.625450Z",
     "iopub.status.idle": "2024-01-10T19:30:28.439813Z",
     "shell.execute_reply": "2024-01-10T19:30:28.439042Z",
     "shell.execute_reply.started": "2024-01-10T19:30:22.625759Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'2.1.0+libtpu'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch_xla\n",
    "torch_xla.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike JAX or TensorFlow, the convention in PyTorch is to start a separate child process per device to minimize the impact of Python's [Global Interpreter Lock](https://en.wikipedia.org/wiki/Global_interpreter_lock). In eager PyTorch, this means spawning one child process per GPU. For more information, see [PyTorch's distributed training documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel).\n",
    "\n",
    "Due to architectural constraints, it is not possible for more than one process to access a TPU chip simultaneously. Because TPU v3 has two TensorCores cores per TPU chip, that means that each process must drive at least two TPU cores. By default, PyTorch/XLA will spawn 4 processes in total (one per chip), each having two threads (one per TensorCore). This is handled transparently by `xmp.spawn`, which mirrors `mp.spawn`. However, it is important to keep in mind that _all distributed workloads on a TPU v2 or v3 are multithreaded_. The function you pass to `spawn` should be thread-safe.\n",
    "\n",
    "TPU v4 has a different architecture, where each TPU chip is represented to PyTorch as a single device, so we spawn one process per device as expected.\n",
    "\n",
    "See the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for an in-depth look at TPU architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "!printenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:28.574793Z",
     "iopub.status.busy": "2024-01-10T19:30:28.574526Z",
     "iopub.status.idle": "2024-01-10T19:30:28.580178Z",
     "shell.execute_reply": "2024-01-10T19:30:28.579554Z",
     "shell.execute_reply.started": "2024-01-10T19:30:28.574764Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Temporary hack: remove some TPU environment variables to support multiprocessing\n",
    "# These will be set later by xmp.spawn.\n",
    "\n",
    "import os\n",
    "os.environ.pop('TPU_PROCESS_ADDRESSES')\n",
    "os.environ.pop('CLOUD_TPU_TASK_ID')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:28.581455Z",
     "iopub.status.busy": "2024-01-10T19:30:28.581117Z",
     "iopub.status.idle": "2024-01-10T19:30:28.606274Z",
     "shell.execute_reply": "2024-01-10T19:30:28.605671Z",
     "shell.execute_reply.started": "2024-01-10T19:30:28.581429Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import torch_xla.core.xla_model as xm\n",
    "import torch_xla.distributed.xla_multiprocessing as xmp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the current process/thread's default XLA device, use `torch_xla.device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:28.607393Z",
     "iopub.status.busy": "2024-01-10T19:30:28.607138Z",
     "iopub.status.idle": "2024-01-10T19:30:28.664032Z",
     "shell.execute_reply": "2024-01-10T19:30:28.662583Z",
     "shell.execute_reply.started": "2024-01-10T19:30:28.607368Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import multiprocessing as mp\n",
    "lock = mp.Manager().Lock()\n",
    "\n",
    "def print_device(i, lock):\n",
    "    device = torch_xla.device()\n",
    "    with lock:\n",
    "        print('process', i, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run a function on each TPU device, pass it to `xmp.spawn`. We'll use an `mp.Lock` to prevent `print` statements from overlapping between processes. This make the output clearer, but it is optional.\n",
    "\n",
    "Note: in interactive notebooks, you must use `start_method='fork'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:28.666657Z",
     "iopub.status.busy": "2024-01-10T19:30:28.666339Z",
     "iopub.status.idle": "2024-01-10T19:30:33.218095Z",
     "shell.execute_reply": "2024-01-10T19:30:33.216950Z",
     "shell.execute_reply.started": "2024-01-10T19:30:28.666621Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "process 4 xla:0\n",
      "process 0 xla:0\n",
      "process 1 xla:1\n",
      "process 6 xla:0\n",
      "process 7 xla:1\n",
      "process 5 xla:1\n",
      "process 2 xla:0\n",
      "process 3 xla:1\n"
     ]
    }
   ],
   "source": [
    "xmp.spawn(print_device, args=(lock,), start_method='fork')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: ignore the errors from `oauth2_credentials.cc`. These will be fixed in a future release."
   ]
  },
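  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The device strings printed above are local to each process, so they cannot tell replicas apart on their own. As a minimal sketch, assuming the installed PyTorch/XLA release provides the `torch_xla.runtime` module (it is not used elsewhere in this notebook, and the function name `print_ordinals` is only illustrative), the cell below prints each replica's global ordinal and the total number of replicas next to its local device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import torch_xla.runtime as xr  # assumption: available in the installed release\n",
    "\n",
    "def print_ordinals(i, lock):\n",
    "    device = torch_xla.device()\n",
    "    with lock:\n",
    "        # `i` is the index passed in by xmp.spawn; the ordinal and world size come from the runtime\n",
    "        print('index', i, 'global ordinal', xr.global_ordinal(), 'of', xr.world_size(), 'on', device)\n",
    "\n",
    "xmp.spawn(print_ordinals, args=(lock,), start_method='fork')"
   ]
  },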
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run `torch` operations on a TPU, pass the corresponding XLA device in as the `device` parameter. When you pass in an XLA device, the operation is added to a graph, which is executed lazily as needed. To force all devices to evaluate the current graph, call `torch_xla.sync()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:33.219878Z",
     "iopub.status.busy": "2024-01-10T19:30:33.219569Z",
     "iopub.status.idle": "2024-01-10T19:30:35.653084Z",
     "shell.execute_reply": "2024-01-10T19:30:35.651887Z",
     "shell.execute_reply.started": "2024-01-10T19:30:33.219847Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:0')\n",
      "7 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:1')\n",
      "0 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:0')\n",
      "2 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:0')\n",
      "3 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:1')\n",
      "1 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:1')\n",
      "4 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:0')\n",
      "5 tensor([[2., 2., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 2., 2.]], device='xla:1')\n"
     ]
    }
   ],
   "source": [
    "def add_ones(i, lock):\n",
    "    x = torch.ones((3, 3), device='xla')\n",
    "    y = x + x\n",
    "\n",
    "    # Run graph to compute `y` before printing\n",
    "    torch_xla.sync()\n",
    "\n",
    "    with lock:\n",
    "        print(i, y)\n",
    "\n",
    "xmp.spawn(add_ones, args=(lock,), start_method='fork')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To communicate tensors between TPU devices, use the collective communication operations in `xla_model`, such as `all_gather`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:35.656796Z",
     "iopub.status.busy": "2024-01-10T19:30:35.656377Z",
     "iopub.status.idle": "2024-01-10T19:30:38.314318Z",
     "shell.execute_reply": "2024-01-10T19:30:38.313118Z",
     "shell.execute_reply.started": "2024-01-10T19:30:35.656763Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7 tensor([7], device='xla:1')\n",
      "6 tensor([6], device='xla:0')\n",
      "4 tensor([4], device='xla:0')\n",
      "5 tensor([5], device='xla:1')\n",
      "3 tensor([3], device='xla:1')\n",
      "2 tensor([2], device='xla:0')\n",
      "1 tensor([1], device='xla:1')\n",
      "0 tensor([0], device='xla:0')\n",
      "7 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n",
      "6 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n",
      "5 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n",
      "4 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n",
      "3 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n",
      "2 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n",
      "0 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n",
      "1 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n"
     ]
    }
   ],
   "source": [
    "def gather_ids(i, lock):\n",
    "    # Create a tensor on each device with the device ID\n",
    "    t = torch.tensor([i], device='xla')\n",
    "    with lock:\n",
    "        print(i, t)\n",
    "\n",
    "    # Collect and concatenate the IDs\n",
    "    ts = xm.all_gather(t)\n",
    "    torch_xla.sync()\n",
    "    with lock:\n",
    "        print(i, ts)\n",
    "\n",
    "xmp.spawn(gather_ids, args=(lock,), start_method='fork')"
   ]
  },
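  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`all_gather` is one of several collectives in `xla_model`. As a minimal sketch, the cell below uses `xm.all_reduce` with `xm.REDUCE_SUM` to sum one value per replica; the function name `sum_ids` is only illustrative. Every replica should end up holding the same total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def sum_ids(i, lock):\n",
    "    # One single-element tensor per device holding the index passed in by xmp.spawn\n",
    "    t = torch.tensor([i], dtype=torch.float32, device='xla')\n",
    "\n",
    "    # Sum the values across all replicas; each replica receives the result\n",
    "    total = xm.all_reduce(xm.REDUCE_SUM, t)\n",
    "    torch_xla.sync()\n",
    "    with lock:\n",
    "        print(i, total)\n",
    "\n",
    "xmp.spawn(sum_ids, args=(lock,), start_method='fork')"
   ]
  },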
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch/XLA 2.0 also ships with experimental support for the `torch.distributed` using the `pjrt://` `init_method`, including `DistributedDataParallel`.\n",
    "\n",
    "Because replicas are run multithreaded, the distributed function must be thread-safe. However, the global RNG that `torch` uses for module initialization will give inconsistent results between replicas on TPU v3, since there will be multiple threads concurrently using it. To ensure consistent parameters, we recommend broadcasting model parameters from replica 0 to the other replicas using `pjrt.broadcast_master_param`. In practice, you may also load each replica's parameters from a common checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-10T19:30:38.315927Z",
     "iopub.status.busy": "2024-01-10T19:30:38.315653Z",
     "iopub.status.idle": "2024-01-10T19:30:43.491104Z",
     "shell.execute_reply": "2024-01-10T19:30:43.490078Z",
     "shell.execute_reply.started": "2024-01-10T19:30:38.315899Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Patching torch.distributed state to support multithreading.\n",
      "WARNING:root:torch.distributed support on TPU v2 and v3 is experimental and does not support torchrun.\n",
      "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n",
      "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n",
      "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n",
      "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n",
      "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n",
      "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n",
      "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 [tensor(-0.0005, device='xla:1', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:1', grad_fn=<MeanBackward0>)]\n",
      "4 [tensor(-0.0005, device='xla:0', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:0', grad_fn=<MeanBackward0>)]\n",
      "6 [tensor(-0.0005, device='xla:0', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:0', grad_fn=<MeanBackward0>)]\n",
      "7 [tensor(-0.0005, device='xla:1', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:1', grad_fn=<MeanBackward0>)]\n",
      "0 [tensor(-0.0005, device='xla:0', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:0', grad_fn=<MeanBackward0>)]\n",
      "5 [tensor(-0.0005, device='xla:1', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:1', grad_fn=<MeanBackward0>)]\n",
      "3 [tensor(-0.0005, device='xla:1', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:1', grad_fn=<MeanBackward0>)]\n",
      "2 [tensor(-0.0005, device='xla:0', grad_fn=<MeanBackward0>), tensor(-0.0019, device='xla:0', grad_fn=<MeanBackward0>)]\n"
     ]
    }
   ],
   "source": [
    "import torch.distributed as dist\n",
    "import torch.nn as nn\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "import torch.optim as optim\n",
    "\n",
    "import torch_xla.distributed.xla_backend # Registers `xla://` init_method\n",
    "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
    "\n",
    "def toy_model(index, lock):\n",
    "    device = torch_xla.device()\n",
    "    dist.init_process_group('xla', init_method='xla://')\n",
    "\n",
    "    # Initialize a basic toy model\n",
    "    torch.manual_seed(42)\n",
    "    model = nn.Linear(128, 10).to(device)\n",
    "\n",
    "    # Optional for TPU v4\n",
    "    xm.broadcast_master_param(model)\n",
    "\n",
    "    model = DDP(model, gradient_as_bucket_view=True)\n",
    "\n",
    "    loss_fn = nn.MSELoss()\n",
    "    optimizer = optim.SGD(model.parameters(), lr=.001)\n",
    "\n",
    "    for i in range(10):\n",
    "        # Generate random inputs and outputs on the XLA device\n",
    "        data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = loss_fn(output, target)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        # Run the pending graph\n",
    "        torch_xla.sync()\n",
    "\n",
    "    with lock:\n",
    "        # Print mean parameters so we can confirm they're the same across replicas\n",
    "        print(index, [p.mean() for p in model.parameters()])\n",
    "\n",
    "xmp.spawn(toy_model, args=(lock,), start_method='fork')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a more in-depth look at PyTorch/XLA, see our [API guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
